import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Reproducibility: one shared constant seeds both Python's and PyTorch's RNGs.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

# ------------------------------------------------------------
# 1. VOCABULARY & TEMPLATES (Clean, compositional)
# ------------------------------------------------------------
# Content-word pools. The "_A" pools fill the relevant templates and the "_B"
# pools fill the distractor templates; adjectives are shared by both.
ADJECTIVES = ["big", "small", "black", "white", "old", "young", "lazy", "quick"]
NOUNS_A    = ["cat", "bird", "fish", "fox"]     # relevant nouns
NOUNS_B    = ["dog", "horse", "bear", "wolf"]   # distractor nouns
VERBS_A    = ["sat", "slept", "hid", "rested"]
VERBS_B    = ["ran", "jumped", "played", "howled"]
PREPS_A    = ["on", "under", "near", "beside"]
PREPS_B    = ["in", "through", "past", "over"]
LOCS_A     = ["mat", "chair", "sofa", "rug"]
LOCS_B     = ["park", "river", "road", "field"]

FUNCTION = ["the", "a", "an", "one", "and", "quietly", "quickly"]
SPECIAL  = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]

# Special tokens come first so they receive the lowest ids.
vocab_list = (
    SPECIAL + FUNCTION + ADJECTIVES
    + NOUNS_A + NOUNS_B
    + VERBS_A + VERBS_B
    + PREPS_A + PREPS_B
    + LOCS_A + LOCS_B
)
# Deduplicate while keeping first-occurrence order (dict keys preserve
# insertion order).
VOCAB = list(dict.fromkeys(vocab_list))
seen = set(VOCAB)

# Token <-> id lookup tables; a token's id is its position in VOCAB.
word2idx = dict(zip(VOCAB, range(len(VOCAB))))
idx2word = dict(enumerate(VOCAB))
VOCAB_SIZE = len(VOCAB)
PAD_IDX = word2idx["<PAD>"]
SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
UNK_IDX = word2idx["<UNK>"]

print(f"Vocabulary size: {VOCAB_SIZE}")

# Sentence templates with compositional slots. Each template is written as a
# space-separated spec; upper-case names (ADJ, NOUN_A, ...) are slots and
# every other word is emitted literally.
RELEVANT_TEMPLATES = [
    spec.split()
    for spec in (
        "the NOUN_A VERB_A PREP_A the LOC_A",
        "a ADJ NOUN_A VERB_A PREP_A a LOC_A",
        "the ADJ NOUN_A VERB_A PREP_A the LOC_A",
        "a NOUN_A VERB_A PREP_A a LOC_A",
        "the NOUN_A and the LOC_A VERB_A",
        "a ADJ NOUN_A quietly VERB_A PREP_A the LOC_A",
        "the NOUN_A VERB_A PREP_A a ADJ LOC_A",
        "a NOUN_A VERB_A beside the LOC_A",
        "the ADJ NOUN_A VERB_A on the LOC_A",
        "one NOUN_A VERB_A PREP_A the LOC_A",
    )
]

# Same sentence shapes as above, but drawing from the "_B" slots.
DISTRACTOR_TEMPLATES = [
    spec.split()
    for spec in (
        "the NOUN_B VERB_B PREP_B the LOC_B",
        "a ADJ NOUN_B VERB_B PREP_B a LOC_B",
        "the ADJ NOUN_B VERB_B PREP_B the LOC_B",
        "a NOUN_B VERB_B PREP_B a LOC_B",
        "the NOUN_B and the LOC_B VERB_B",
        "a ADJ NOUN_B quickly VERB_B PREP_B the LOC_B",
        "the NOUN_B VERB_B PREP_B a ADJ LOC_B",
        "a NOUN_B VERB_B through the LOC_B",
        "the ADJ NOUN_B VERB_B in the LOC_B",
        "one NOUN_B VERB_B PREP_B the LOC_B",
    )
]

def fill_template(template, is_relevant):
    """Instantiate a template by replacing each slot with a random word.

    Args:
        template: List of tokens; slot names (``ADJ``, ``NOUN_A``, ...) are
            replaced, every other token is copied through unchanged.
        is_relevant: Accepted for call-site symmetry; not used here because
            the slot names already select the word pool.

    Returns:
        A new list of words, one per template token.
    """
    slot_pools = {
        "ADJ": ADJECTIVES,
        "NOUN_A": NOUNS_A,
        "NOUN_B": NOUNS_B,
        "VERB_A": VERBS_A,
        "VERB_B": VERBS_B,
        "PREP_A": PREPS_A,
        "PREP_B": PREPS_B,
        "LOC_A": LOCS_A,
        "LOC_B": LOCS_B,
    }
    # One random.choice per slot, in template order.
    return [
        random.choice(slot_pools[tok]) if tok in slot_pools else tok
        for tok in template
    ]

def make_goal(sent_tokens, is_relevant):
    """Derive the ``[noun, preposition, location]`` goal for a sentence.

    Args:
        sent_tokens: Words of the generated sentence.
        is_relevant: Selects the "_A" word pools when true, the "_B" pools
            otherwise.

    Returns:
        A three-element list holding the first noun, preposition and location
        of the chosen pools found in the sentence. A missing element falls
        back to ``"cat"``, ``"on"`` or ``"mat"`` respectively, regardless of
        ``is_relevant`` (e.g. templates without a preposition yield ``"on"``).
    """
    if is_relevant:
        noun_pool, prep_pool, loc_pool = NOUNS_A, PREPS_A, LOCS_A
    else:
        noun_pool, prep_pool, loc_pool = NOUNS_B, PREPS_B, LOCS_B

    def first_in(pool, fallback):
        # First sentence token belonging to `pool`, else the fallback word.
        return next((tok for tok in sent_tokens if tok in pool), fallback)

    return [
        first_in(noun_pool, "cat"),
        first_in(prep_pool, "on"),
        first_in(loc_pool, "mat"),
    ]

# ------------------------------------------------------------
# 2. DATASET
# ------------------------------------------------------------
def generate_dataset(n_per_type=10000):
    """Build a shuffled corpus of relevant and distractor samples.

    Args:
        n_per_type: Number of samples to draw for each of the two kinds.

    Returns:
        A list of ``(goal, sentence, is_relevant)`` tuples, where ``goal`` and
        ``sentence`` are lists of words. All relevant samples are generated
        first, then all distractors, and the combined list is shuffled.
    """
    samples = []
    sources = ((RELEVANT_TEMPLATES, True), (DISTRACTOR_TEMPLATES, False))
    for templates, flag in sources:
        for _ in range(n_per_type):
            template = random.choice(templates)
            sentence = fill_template(template, is_relevant=flag)
            goal = make_goal(sentence, is_relevant=flag)
            samples.append((goal, sentence, flag))
    random.shuffle(samples)
    return samples

# Build the training corpus once at import time: 10000 relevant plus 10000
# distractor (goal, sentence, is_relevant) samples, shuffled together.
data = generate_dataset(10000)
print(f"Training samples: {len(data)}")

# PyTorch datasets
class SentenceOnlyDataset(Dataset):
    """Sentences only, encoded as ``<SOS> w1 ... wn <EOS>`` id tensors.

    Args:
        raw: Iterable of ``(goal, sentence, is_relevant)`` samples; only the
            sentence is used. Words missing from the vocabulary map to <UNK>.
    """

    def __init__(self, raw):
        self.items = []
        for _goal, sentence, _flag in raw:
            ids = [SOS_IDX]
            ids.extend(word2idx.get(word, UNK_IDX) for word in sentence)
            ids.append(EOS_IDX)
            self.items.append(torch.tensor(ids, dtype=torch.long))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]

class GoalSentenceDataset(Dataset):
    """(goal, sentence) pairs, each encoded as a ``<SOS> ... <EOS>`` id tensor.

    Args:
        raw: Iterable of ``(goal, sentence, is_relevant)`` samples; the flag is
            ignored. Words missing from the vocabulary map to <UNK>.
    """

    def __init__(self, raw):
        def encode(words):
            # Wrap the word ids in start/end markers and build a long tensor.
            ids = [SOS_IDX, *(word2idx.get(w, UNK_IDX) for w in words), EOS_IDX]
            return torch.tensor(ids, dtype=torch.long)

        self.items = [(encode(goal), encode(sentence)) for goal, sentence, _ in raw]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]

def collate_sent(batch):
    """Pad a batch of 1-D id tensors with PAD_IDX into one batch-first tensor."""
    padded = pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)
    return padded
def collate_pair(batch):
    """Collate ``(goal, sentence)`` tensor pairs into two padded batches.

    Returns:
        ``(goals, sentences)``: two batch-first tensors, each padded with
        PAD_IDX to the longest sequence on its own side.
    """
    goal_seqs, sent_seqs = zip(*batch)
    padded_goals = pad_sequence(goal_seqs, batch_first=True, padding_value=PAD_IDX)
    padded_sents = pad_sequence(sent_seqs, batch_first=True, padding_value=PAD_IDX)
    return (padded_goals, padded_sents)

# Both datasets wrap the same raw samples: one exposes sentences only (for the
# unconditional LM), the other (goal, sentence) pairs (for the seq2seq models).
sent_ds = SentenceOnlyDataset(data)
pair_ds = GoalSentenceDataset(data)
# Shuffled mini-batches of 64; the collate functions pad each batch with PAD_IDX.
sent_loader = DataLoader(sent_ds, batch_size=64, shuffle=True, collate_fn=collate_sent)
pair_loader = DataLoader(pair_ds, batch_size=64, shuffle=True, collate_fn=collate_pair)

# ------------------------------------------------------------
# 3. MODELS (small, correct architecture)
# ------------------------------------------------------------
class UniLSTM(nn.Module):
    """Unconditional language model: embedding -> stacked LSTM -> vocab logits.

    Args:
        embed_dim: Token embedding width.
        hidden_dim: LSTM hidden size.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout passed to ``nn.LSTM``.
    """

    def __init__(self, embed_dim=64, hidden_dim=128, n_layers=2, dropout=0.3):
        super().__init__()
        self.emb = nn.Embedding(
            num_embeddings=VOCAB_SIZE,
            embedding_dim=embed_dim,
            padding_idx=PAD_IDX,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=VOCAB_SIZE)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for batch-first token ids ``x``.

        ``hidden`` is the optional LSTM state to resume from; the updated
        state is returned so decoding can continue step by step.
        """
        embedded = self.emb(x)
        states, hidden = self.lstm(embedded, hidden)
        logits = self.fc(states)
        return logits, hidden

class BidirEncoder(nn.Module):
    """Bidirectional LSTM encoder producing a decoder-sized initial state.

    Args:
        embed_dim: Token embedding width.
        hidden_dim: Per-direction LSTM hidden size (also the output size).
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout passed to ``nn.LSTM``.
    """

    def __init__(self, embed_dim=64, hidden_dim=128, n_layers=2, dropout=0.3):
        super().__init__()
        self.emb = nn.Embedding(VOCAB_SIZE, embed_dim, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        # Project the concatenated forward/backward states back to hidden_dim.
        self.fc_h = nn.Linear(2 * hidden_dim, hidden_dim)
        self.fc_c = nn.Linear(2 * hidden_dim, hidden_dim)
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

    def forward(self, src):
        """Encode batch-first ids ``src``.

        Returns:
            ``(h, c)``, each of shape ``(n_layers, batch, hidden_dim)``.
        """
        batch_size = src.size(0)
        _, (h_n, c_n) = self.lstm(self.emb(src))

        def merge_directions(state):
            # (n_layers * 2, B, H) -> (n_layers, 2, B, H), then join the two
            # directions of each layer along the feature axis.
            per_layer = state.view(self.n_layers, 2, batch_size, self.hidden_dim)
            return torch.cat((per_layer[:, 0], per_layer[:, 1]), dim=-1)

        h_init = torch.tanh(self.fc_h(merge_directions(h_n)))
        c_init = torch.tanh(self.fc_c(merge_directions(c_n)))
        return h_init, c_init

class UniDecoder(nn.Module):
    """Unidirectional LSTM decoder started from an externally supplied state.

    Args:
        embed_dim: Token embedding width.
        hidden_dim: LSTM hidden size.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout passed to ``nn.LSTM``.
    """

    def __init__(self, embed_dim=64, hidden_dim=128, n_layers=2, dropout=0.3):
        super().__init__()
        self.emb = nn.Embedding(
            num_embeddings=VOCAB_SIZE,
            embedding_dim=embed_dim,
            padding_idx=PAD_IDX,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=VOCAB_SIZE)

    def forward(self, x, hidden):
        """Return ``(logits, new_hidden)`` for batch-first ids ``x`` and state ``hidden``."""
        embedded = self.emb(x)
        states, new_hidden = self.lstm(embedded, hidden)
        return self.fc(states), new_hidden

class Seq2SeqLSTM(nn.Module):
    """Goal-conditioned LSTM: a BidirEncoder state initialises a UniDecoder."""

    def __init__(self):
        super().__init__()
        self.encoder = BidirEncoder()
        self.decoder = UniDecoder()

    def forward(self, goal, tgt_in):
        """Return decoder logits for ``tgt_in`` conditioned on ``goal``.

        The decoder's final state is discarded; only the logits are returned.
        """
        init_state = self.encoder(goal)
        logits, _final_state = self.decoder(tgt_in, init_state)
        return logits

class TransformerS2S(nn.Module):
    """Goal-conditioned encoder-decoder Transformer.

    Source and target ids go through separate embeddings, are scaled by
    ``sqrt(d_model)``, receive fixed sinusoidal positional encodings, and are
    fed to a batch-first ``nn.Transformer``. The decoder output is projected
    to vocabulary logits.

    Args:
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_enc: Number of encoder layers.
        num_dec: Number of decoder layers.
        dropout: Dropout probability inside ``nn.Transformer``.
        vocab_size: Vocabulary size; defaults to the module-level VOCAB_SIZE.
        pad_idx: Padding token id; defaults to the module-level PAD_IDX.
        max_len: Longest sequence the positional table supports.
    """

    def __init__(self, d_model=128, nhead=4, num_enc=2, num_dec=2, dropout=0.1,
                 vocab_size=None, pad_idx=None, max_len=200):
        super().__init__()
        # Fall back to the file-level vocabulary when not given explicitly.
        if vocab_size is None:
            vocab_size = VOCAB_SIZE
        if pad_idx is None:
            pad_idx = PAD_IDX
        self.src_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        # Non-persistent buffer: follows the module across .to()/.cuda() without
        # a per-forward copy, and stays out of state_dict so checkpoints keep
        # the same keys as before.
        self.register_buffer(
            "pe", self._positional_encoding(d_model, max_len), persistent=False
        )
        self.transformer = nn.Transformer(
            d_model, nhead, num_enc, num_dec,
            dim_feedforward=256, dropout=dropout, batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size)
        self.scale = math.sqrt(d_model)

    def _positional_encoding(self, d_model, max_len):
        """Return a ``(1, max_len, d_model)`` sinusoidal position table."""
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, src, tgt, src_pad=None, tgt_pad=None):
        """Compute next-token logits for ``tgt`` conditioned on ``src``.

        Args:
            src: Batch-first source ids, ``(batch, src_len)``.
            tgt: Batch-first target input ids, ``(batch, tgt_len)``.
            src_pad: Optional boolean mask, True where ``src`` is padding.
            tgt_pad: Optional boolean mask, True where ``tgt`` is padding.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab_size)``.

        Raises:
            ValueError: If either sequence is longer than the positional table.
        """
        src_len, tgt_len = src.size(1), tgt.size(1)
        table_len = self.pe.size(1)
        if src_len > table_len or tgt_len > table_len:
            raise ValueError(
                f"sequence length {max(src_len, tgt_len)} exceeds positional "
                f"encoding length {table_len}"
            )
        # Boolean causal mask (True = may not attend). Using the same dtype as
        # the boolean padding masks avoids PyTorch's deprecated mixed-dtype
        # attn_mask / key_padding_mask path.
        causal = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )
        src_pe = self.src_emb(src) * self.scale + self.pe[:, :src_len]
        tgt_pe = self.tgt_emb(tgt) * self.scale + self.pe[:, :tgt_len]
        out = self.transformer(src_pe, tgt_pe, tgt_mask=causal,
                               src_key_padding_mask=src_pad,
                               tgt_key_padding_mask=tgt_pad,
                               memory_key_padding_mask=src_pad)
        return self.fc(out)

# ------------------------------------------------------------
# 4. TRAINING
# ------------------------------------------------------------
def train_uni(model, loader, epochs=10, lr=1e-3):
    """Train an unconditional LM with next-token cross-entropy.

    Args:
        model: Module whose call returns ``(logits, hidden)``.
        loader: Yields padded batch-first id tensors.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.

    Prints the mean batch loss after every epoch; returns nothing.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0
        for batch in loader:
            batch = batch.to(device)
            # Predict token t+1 from tokens up to t.
            inputs = batch[:, :-1]
            targets = batch[:, 1:]
            logits, _ = model(inputs)
            loss = criterion(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
        print(f"  UniLSTM Ep {epoch:02d} | loss {running/len(loader):.4f}")

def train_seq2seq(model, loader, epochs=10, lr=1e-3, tag="Seq2SeqLSTM"):
    """Train a goal-conditioned model with teacher forcing.

    Args:
        model: Module called as ``model(goals, target_inputs)`` returning logits.
        loader: Yields ``(goals, sentences)`` padded batch-first id tensors.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.
        tag: Label used in the per-epoch progress line.

    Prints the mean batch loss after every epoch; returns nothing.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0
        for goal_batch, sent_batch in loader:
            goal_batch = goal_batch.to(device)
            sent_batch = sent_batch.to(device)
            # Teacher forcing: feed the sentence shifted right, predict it shifted left.
            dec_in = sent_batch[:, :-1]
            dec_out = sent_batch[:, 1:]
            logits = model(goal_batch, dec_in)
            loss = criterion(logits.reshape(-1, VOCAB_SIZE), dec_out.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
        print(f"  {tag} Ep {epoch:02d} | loss {running/len(loader):.4f}")

def train_transformer(model, loader, epochs=10, lr=5e-4):
    """Train the Transformer seq2seq model with teacher forcing.

    Args:
        model: Module called as ``model(src, tgt_in, src_pad, tgt_pad)``
            returning logits.
        loader: Yields ``(goals, sentences)`` padded batch-first id tensors.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.

    Prints the mean batch loss after every epoch; returns nothing.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0
        for goal_batch, sent_batch in loader:
            goal_batch = goal_batch.to(device)
            sent_batch = sent_batch.to(device)
            dec_in = sent_batch[:, :-1]
            dec_out = sent_batch[:, 1:]
            # Boolean masks marking padded positions on each side.
            src_pad_mask = goal_batch == PAD_IDX
            tgt_pad_mask = dec_in == PAD_IDX
            logits = model(goal_batch, dec_in, src_pad_mask, tgt_pad_mask)
            loss = criterion(logits.reshape(-1, VOCAB_SIZE), dec_out.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
        print(f"  Transformer Ep {epoch:02d} | loss {running/len(loader):.4f}")

# ------------------------------------------------------------
# 5. GENERATION & EVALUATION
# ------------------------------------------------------------
@torch.no_grad()
def gen_uni(model, seed="the", max_len=15):
    """Greedily continue the prompt ``<SOS> seed`` with the unconditional LM.

    Args:
        model: Language model whose call ``model(x, hidden)`` returns
            ``(logits, hidden)``.
        seed: First word of the sentence; out-of-vocabulary words map to <UNK>.
        max_len: Maximum number of tokens to generate after the seed.

    Returns:
        The generated continuation as a list of words. The seed word itself
        is not included, and <PAD>/<SOS> tokens are dropped.
    """
    model.eval()
    seed_idx = word2idx.get(seed, UNK_IDX)
    # The first step consumes the whole prompt from a fresh state (hid=None);
    # later steps feed back one token at a time.
    x = torch.tensor([[SOS_IDX, seed_idx]], dtype=torch.long, device=device)
    hid = None
    out = []
    for _ in range(max_len):
        logits, hid = model(x, hid)
        nxt = logits[0, -1].argmax().item()
        # Check for <EOS> on every step, including the very first prediction,
        # so an immediate end-of-sentence neither leaks into the output nor
        # gets fed back into the model.
        if nxt == EOS_IDX:
            break
        out.append(nxt)
        x = torch.tensor([[nxt]], dtype=torch.long, device=device)
    return [idx2word[i] for i in out if i not in (PAD_IDX, SOS_IDX)]

@torch.no_grad()
def gen_s2s(model, goal_tokens, max_len=15):
    """Greedily decode a sentence from the LSTM seq2seq model.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` sub-modules.
        goal_tokens: Goal words; out-of-vocabulary words map to <UNK>.
        max_len: Maximum number of decoding steps.

    Returns:
        Generated words up to (not including) <EOS>, with <PAD>/<SOS> dropped.
    """
    model.eval()
    goal_ids = [SOS_IDX, *(word2idx.get(w, UNK_IDX) for w in goal_tokens), EOS_IDX]
    src = torch.tensor([goal_ids], dtype=torch.long, device=device)
    state = model.encoder(src)

    prev_id = SOS_IDX
    produced = []
    for _ in range(max_len):
        step_in = torch.tensor([[prev_id]], dtype=torch.long, device=device)
        logits, state = model.decoder(step_in, state)
        prev_id = logits[0, -1].argmax().item()
        if prev_id == EOS_IDX:
            break
        produced.append(prev_id)

    skipped = (PAD_IDX, SOS_IDX)
    return [idx2word[i] for i in produced if i not in skipped]

@torch.no_grad()
def gen_tf(model, goal_tokens, max_len=15):
    """Greedily decode a sentence from the Transformer seq2seq model.

    The full target prefix is re-fed to the model at every step.

    Args:
        model: Module called as ``model(src, tgt)`` returning logits.
        goal_tokens: Goal words; out-of-vocabulary words map to <UNK>.
        max_len: Maximum number of decoding steps.

    Returns:
        Generated words up to (not including) <EOS>, with <PAD>/<SOS> dropped.
    """
    model.eval()
    src_ids = [SOS_IDX, *(word2idx.get(w, UNK_IDX) for w in goal_tokens), EOS_IDX]
    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    prefix = [SOS_IDX]
    for _ in range(max_len):
        tgt = torch.tensor([prefix], dtype=torch.long, device=device)
        step_logits = model(src, tgt)[0, -1]
        next_id = step_logits.argmax().item()
        if next_id == EOS_IDX:
            break
        prefix.append(next_id)

    generated = prefix[1:]  # drop the leading <SOS>
    return [idx2word[i] for i in generated if i != PAD_IDX and i != SOS_IDX]

def relevance(sentence_tokens, goal_tokens):
    """Check whether a sentence mentions the goal's first and last word in order.

    Only ``goal_tokens[0]`` and ``goal_tokens[-1]`` are inspected; any goal
    words in between (e.g. the preposition) are not checked.

    Args:
        sentence_tokens: List of words in the generated sentence.
        goal_tokens: Non-empty list of goal words.

    Returns:
        True when both words occur and the first occurrence of the last goal
        word lies 1 to 4 positions after the first occurrence of the first
        goal word; False otherwise.
    """
    head = goal_tokens[0]
    if head not in sentence_tokens:
        return False
    tail = goal_tokens[-1]
    if tail not in sentence_tokens:
        return False
    gap = sentence_tokens.index(tail) - sentence_tokens.index(head)
    return 0 < gap <= 4

# ------------------------------------------------------------
# 6. MAIN
# ------------------------------------------------------------
# Train the three models in turn. The unconditional LM sees sentences only;
# the two seq2seq models see (goal, sentence) pairs.
print("\n=== Training UniLSTM ===")
uni = UniLSTM()
train_uni(uni, sent_loader)

print("\n=== Training Seq2Seq LSTM ===")
s2s = Seq2SeqLSTM()
train_seq2seq(s2s, pair_loader, tag="Seq2SeqLSTM")

print("\n=== Training Transformer ===")
tf = TransformerS2S()
train_transformer(tf, pair_loader)

# Test goals: (noun, preposition, location) triples built from the "relevant"
# word lists.
# NOTE(review): generate_dataset fills every slot by uniform random choice from
# these same 4-word lists, so with 10000 relevant samples these combinations
# are very likely present in the training data. They should not be treated as
# "unseen" unless they are explicitly held out of training.
test_goals = [
    ["fox", "under", "sofa"],
    ["bird", "beside", "rug"],
    ["fish", "near", "chair"],
    ["cat", "on", "rug"],
    ["fox", "near", "mat"],
    ["bird", "on", "sofa"],
    ["fish", "under", "chair"],
    ["cat", "beside", "sofa"],
    ["fox", "on", "chair"],
    ["bird", "under", "mat"],
]
N_TEST = 1000

# Per-model counts of generations judged relevant to the sampled goal.
uni_rel = 0
s2s_rel = 0
tf_rel  = 0

print("\n=== Evaluating ===")
for i in range(N_TEST):
    goal = random.choice(test_goals)
    # gen_uni receives no goal and decodes greedily from its default seed
    # word, so it produces the same sentence on every iteration.
    uni_sent = gen_uni(uni)
    s2s_sent = gen_s2s(s2s, goal)
    tf_sent  = gen_tf(tf, goal)
    # relevance() returns a bool, which adds to the counters as 0 or 1.
    uni_rel += relevance(uni_sent, goal)
    s2s_rel += relevance(s2s_sent, goal)
    tf_rel  += relevance(tf_sent, goal)

# Report each count as a percentage of the N_TEST trials.
print("\n=== RESULTS ===")
print(f"UniLSTM Relevance:       {uni_rel/N_TEST*100:.1f}%")
print(f"Seq2Seq LSTM Relevance:  {s2s_rel/N_TEST*100:.1f}%")
print(f"Transformer Relevance:   {tf_rel/N_TEST*100:.1f}%")

# Print three example generations per model for a qualitative look.
print("\n=== SAMPLE GENERATIONS ===")
print("\nUniLSTM:")
for _ in range(3):
    print("  ", " ".join(gen_uni(uni)))
print("\nSeq2Seq LSTM:")
for _ in range(3):
    goal = random.choice(test_goals)
    print(f"  Goal: {' '.join(goal)}")
    print(f"  Gen:  {' '.join(gen_s2s(s2s, goal))}")
print("\nTransformer:")
for _ in range(3):
    goal = random.choice(test_goals)
    print(f"  Goal: {' '.join(goal)}")
    print(f"  Gen:  {' '.join(gen_tf(tf, goal))}")
---------------------------------------------------------------
Device: cuda
Vocabulary size: 51
Training samples: 20000

=== Training UniLSTM ===
  UniLSTM Ep 01 | loss 1.7808
  UniLSTM Ep 02 | loss 1.2196
  UniLSTM Ep 03 | loss 1.2111
  UniLSTM Ep 04 | loss 1.2075
  UniLSTM Ep 05 | loss 1.2068
  UniLSTM Ep 06 | loss 1.2055
  UniLSTM Ep 07 | loss 1.2048
  UniLSTM Ep 08 | loss 1.2044
  UniLSTM Ep 09 | loss 1.2041
  UniLSTM Ep 10 | loss 1.2032

=== Training Seq2Seq LSTM ===
  Seq2SeqLSTM Ep 01 | loss 1.6258
  Seq2SeqLSTM Ep 02 | loss 0.7927
  Seq2SeqLSTM Ep 03 | loss 0.6322
  Seq2SeqLSTM Ep 04 | loss 0.5700
  Seq2SeqLSTM Ep 05 | loss 0.5632
  Seq2SeqLSTM Ep 06 | loss 0.5604
  Seq2SeqLSTM Ep 07 | loss 0.5585
  Seq2SeqLSTM Ep 08 | loss 0.5570
  Seq2SeqLSTM Ep 09 | loss 0.5561
  Seq2SeqLSTM Ep 10 | loss 0.5558

=== Training Transformer ===
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/activation.py:1336: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.
  key_padding_mask = F._canonical_mask(
  Transformer Ep 01 | loss 1.0039
  Transformer Ep 02 | loss 0.5898
  Transformer Ep 03 | loss 0.5760
  Transformer Ep 04 | loss 0.5707
  Transformer Ep 05 | loss 0.5678
  Transformer Ep 06 | loss 0.5658
  Transformer Ep 07 | loss 0.5644
  Transformer Ep 08 | loss 0.5631
  Transformer Ep 09 | loss 0.5624
  Transformer Ep 10 | loss 0.5621

=== Evaluating ===

=== RESULTS ===
UniLSTM Relevance:       0.0%
Seq2Seq LSTM Relevance:  100.0%
Transformer Relevance:   100.0%

=== SAMPLE GENERATIONS ===

UniLSTM:
   bear and the park jumped
   bear and the park jumped
   bear and the park jumped

Seq2Seq LSTM:
  Goal: fox near mat
  Gen:  a fox slept near a mat
  Goal: bird under mat
  Gen:  a bird slept under a mat
  Goal: fish under chair
  Gen:  a fish rested under a chair

Transformer:
  Goal: bird beside rug
  Gen:  a bird slept beside the rug
  Goal: fox under sofa
  Gen:  the fox slept under the sofa
  Goal: cat on rug
  Gen:  the cat and the rug rested



